{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8a56918-88b7-4d44-8653-778a9305d21d",
   "metadata": {
    "tags": []
   },
   "source": [
    "## ragas评估"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1956cfa-77fe-4613-822d-0943f3fd8b3b",
   "metadata": {},
   "source": [
    "### rag pipline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fa084178-bb13-4b6f-b790-ac0606eb2a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from model import RagEmbedding, RagLLM, QwenLLM\n",
    "from langchain_chroma import Chroma\n",
    "import chromadb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75c1e7c1-ea24-482f-91a5-7762775d66df",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = RagLLM()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c333359-e75e-47c3-bdcc-d7ea38fe99c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"\"\"\n",
    "你是企业员工助手，熟悉公司考勤和报销标准等规章制度，需要根据提供的上下文信息context来回答员工的提问。\\\n",
    "请直接回答问题，如果上下文信息context没有和问题相关的信息，请直接回答[不知道,请咨询HR] \\\n",
    "问题：{question} \n",
    "\"{context}\"\n",
    "回答：\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e7bcbeb3-f6f9-4fb4-ae69-0be904aeeb52",
   "metadata": {},
   "outputs": [],
   "source": [
    "chroma_client = chromadb.HttpClient(host=\"localhost\", port=8000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "310d9b5b-4a85-42f8-9729-1ae82bbfb2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_model = RagEmbedding()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cb34e65f-3cbe-4ea4-a6ba-95ffe7d27969",
   "metadata": {},
   "outputs": [],
   "source": [
    "zhidu_db = Chroma(\"zhidu_db\", \n",
    "                  embedding_model.get_embedding_fun(), \n",
    "                  client=chroma_client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bd1ceeac-bb18-4c36-a564-3f5eed526c73",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_rag_pipline_without_stream(query, k=3):\n",
    "    related_docs = zhidu_db.similarity_search(query, k=k)\n",
    "    context_list = [f\"上下文{i+1}: {doc.page_content} \\n\" \\\n",
    "                         for i, doc in enumerate(related_docs)]\n",
    "    context = \"\\n\".join(context_list)\n",
    "\n",
    "    llm_prompt = prompt_template.replace(\"{question}\", query).replace(\"{context}\", context)\n",
    "    response = llm(llm_prompt, stream=False)\n",
    "    return response, context_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb14df3-4263-4cb5-b210-853cedc25811",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc9d2988-ff64-4c87-950b-88bc11b2ecda",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2951f8a9-9bbf-4196-8a68-2369bc8ba4f7",
   "metadata": {},
   "source": [
    "### 构建评估数据集\n",
    "- 问题\n",
    "- 标准答案\n",
    "- 上下文信息\n",
    "- 生成的答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7cb20280-ad60-4b1a-8a04-e289dbf6ce4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions = [\n",
    "    \"伙食补助费标准是什么?\",\n",
    "    \"出差可以买意外保险吗？需要自己购买吗\",\n",
    "]\n",
    "ground_truths = [\n",
    "    \"伙食补助费标准: 西藏、青海、新疆 120元/人、天 其他省份 100元/人、天\",\n",
    "    \"出差可以购买交通意外保险，由单位统一购买，不再重复购买\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ff90239d-ab7d-49d1-b546-5b739bde4bef",
   "metadata": {},
   "outputs": [],
   "source": [
    "answers = []\n",
    "contexts = []\n",
    "\n",
    "for query in questions:\n",
    "    response, context_list = run_rag_pipline_without_stream(query, k=3)\n",
    "    answers.append(response)\n",
    "    contexts.append(context_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "baa8cfea-77d2-4afa-8d02-2e09210fa249",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9eac7265-1f35-48aa-8b0e-a23a0d2f2b5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    \"question\": questions,\n",
    "    \"answer\": answers,\n",
    "    \"contexts\": contexts,\n",
    "    \"ground_truth\": ground_truths\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "da79fcfc-9cba-406a-b7cb-ff4d741eb653",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = Dataset.from_dict(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb4de6e-bdfd-4657-9012-edb75bb350c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "13a3252a-7c6f-4fe6-a322-ee3c280d429d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ragas.metrics import (\n",
    "    faithfulness,\n",
    "    answer_relevancy,\n",
    "    context_recall,\n",
    "    context_precision,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "44b578e2-d9e0-4fe0-be65-f10e2e30921f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ragas import evaluate\n",
    "from ragas import RunConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "79b13d81-2e58-4afe-8923-9e6bd8bfc4b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_llm = QwenLLM()\n",
    "embedding_model = RagEmbedding()\n",
    "eval_embedding_fn = embedding_model.get_embedding_fun()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "56bc2272-d0dd-4656-bbb8-24e895df3b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = RunConfig(timeout=1200, log_tenacity=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d9624c72-f43d-4d06-a146-e7828e9afe26",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 8/8 [02:47<00:00, 20.97s/it]\n"
     ]
    }
   ],
   "source": [
    "result = evaluate(\n",
    "    dataset = dataset,\n",
    "    llm=eval_llm,\n",
    "    embeddings=eval_embedding_fn,\n",
    "    metrics=[\n",
    "        context_precision,\n",
    "        context_recall,\n",
    "        faithfulness,\n",
    "        answer_relevancy,\n",
    "    ],\n",
    "    raise_exceptions=True,\n",
    "    run_config=config\n",
    ")\n",
    "\n",
    "df = result.to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "771999a6-09e2-4127-bc7c-4a8e81f9c972",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "      <th>contexts</th>\n",
       "      <th>ground_truth</th>\n",
       "      <th>context_precision</th>\n",
       "      <th>context_recall</th>\n",
       "      <th>faithfulness</th>\n",
       "      <th>answer_relevancy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>伙食补助费标准是什么?</td>\n",
       "      <td>伙食补助费的标准为：西藏、青海、新疆地区120元/人、天；其他省份100元/人、天。</td>\n",
       "      <td>[上下文1: &lt;table&gt;&lt;caption&gt;伙食补助费参考以下标准：&lt;/caption&gt;\\...</td>\n",
       "      <td>伙食补助费标准: 西藏、青海、新疆 120元/人、天 其他省份 100元/人、天</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.988139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>出差可以买意外保险吗？需要自己购买吗</td>\n",
       "      <td>可以购买意外保险，每人次可以购买交通意外保险一份。如果公司已经统一购买了交通意外保险，则不需...</td>\n",
       "      <td>[上下文1: 享受副部级待遇及以上人员出差，因工作需要，随行一人可乘坐同等级交通工具；乘坐飞...</td>\n",
       "      <td>出差可以购买交通意外保险，由单位统一购买，不再重复购买</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.763618</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             question                                             answer  \\\n",
       "0         伙食补助费标准是什么?         伙食补助费的标准为：西藏、青海、新疆地区120元/人、天；其他省份100元/人、天。   \n",
       "1  出差可以买意外保险吗？需要自己购买吗  可以购买意外保险，每人次可以购买交通意外保险一份。如果公司已经统一购买了交通意外保险，则不需...   \n",
       "\n",
       "                                            contexts  \\\n",
       "0  [上下文1: <table><caption>伙食补助费参考以下标准：</caption>\\...   \n",
       "1  [上下文1: 享受副部级待遇及以上人员出差，因工作需要，随行一人可乘坐同等级交通工具；乘坐飞...   \n",
       "\n",
       "                               ground_truth  context_precision  \\\n",
       "0  伙食补助费标准: 西藏、青海、新疆 120元/人、天 其他省份 100元/人、天                1.0   \n",
       "1               出差可以购买交通意外保险，由单位统一购买，不再重复购买                1.0   \n",
       "\n",
       "   context_recall  faithfulness  answer_relevancy  \n",
       "0             1.0           1.0          0.988139  \n",
       "1             1.0           0.8          0.763618  "
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
